import os
import json
import ssl
import pandas as pd
import numpy as np
import cv2
import sys
# ssl._create_default_https_context = ssl._create_unverified_context
from tensorflow import keras as K
import tensorflow as tf
import matplotlib.pyplot as plt
import traceback

# Module-level run configuration. main() fills these from sys.argv; until then
# the empty-string / 0 values are placeholders only.
jsonfile = ""  # main() binds a *local* of this name, so this global stays ""
imagefolder = ""  # validation folder: label CSV, outputs and the loaded model live here
labeltable = ""  # DataFrame of the validation label CSV (set by main())
trainingfolder = ""  # folder holding the training images
pretrainmodel = ""
kerasmodel = ""  # Keras model loaded by main()
imgnames = ""  # list of validation image file names (set by main())
labels = ""
cachefolder = ""
traininglabel = ""  # path of the training label CSV read by train()
labelfilename = ""  # name of the validation label CSV inside imagefolder
trainratio = ""  # share of each class used for training (a float once main() sets it)

batchsize = ""
trepoch = ""
optimiset = ""  # optimizer class name passed to optimizers()
lossfunc = ""
learningrate = ""

# Augmentation switches read by imgaugmentation(); only the string "true" enables one.
img_transfer = ""
img_random = ""
img_rotate = ""
img_distort = ""

tf_finetune = ""
freezemodel = ""
dropout_rate = ""
predictclass = ""

k_inputshape = 0  # kerasmodel.input_shape once main() has loaded the model


def truepos(y_true, y_pred):
    """Return the batch sum of round(clip(y_true * y_pred, 0, 1)).

    With 0/1 targets this counts the elements whose target is 1 and whose
    prediction rounds to 1. The result is a scalar tensor.
    """
    backend = K.backend
    overlap = backend.round(backend.clip(y_true * y_pred, 0, 1))
    return backend.sum(overlap)


def trueneg(y_true, y_pred):
    """Count the elements whose label and prediction both round to 0.

    Labels and predictions are clipped to [0, 1] and rounded element-wise, then
    their complements are multiplied and summed over the whole batch.
    Returns a scalar tensor.
    """
    # Use the real labels. Replacing y_true with ones would make every element
    # look positive and force the count to 0.
    y_pred_neg = 1 - K.backend.round(K.backend.clip(y_pred, 0, 1))
    y_neg = 1 - K.backend.round(K.backend.clip(y_true, 0, 1))
    tn = K.backend.sum(K.backend.round(K.backend.clip(y_pred_neg * y_neg, 0, 1)))
    return tn


def falsepos(y_true, y_pred):
    """Count the elements whose label rounds to 0 but whose prediction rounds to 1.

    Labels and predictions are clipped to [0, 1] and rounded element-wise before
    the count is summed over the whole batch. Returns a scalar tensor.
    """
    # Use the real labels. Replacing y_true with ones would make y_neg all zeros
    # and the false-positive count always 0.
    y_neg = 1 - K.backend.round(K.backend.clip(y_true, 0, 1))
    y_pred_pos = K.backend.round(K.backend.clip(y_pred, 0, 1))
    fp = K.backend.sum(y_neg * y_pred_pos)
    return fp


def falseneg(y_true, y_pred):
    """Count the elements whose label rounds to 1 but whose prediction rounds to 0.

    Labels and predictions are clipped to [0, 1] and rounded element-wise before
    the count is summed over the whole batch. Returns a scalar tensor.
    """
    # Use the real labels. Replacing y_true with ones would count every predicted
    # negative as a false negative, regardless of its label.
    y_pos = K.backend.round(K.backend.clip(y_true, 0, 1))
    y_pred_neg = 1 - K.backend.round(K.backend.clip(y_pred, 0, 1))
    fn = K.backend.sum(y_pos * y_pred_neg)
    return fn


def recall(y_true, y_pred):
    """Batch-level recall of the positive class.

    Both arguments are reduced with argmax over the last axis. Index 0 is the
    negative class; any non-zero index counts as positive after clipping to [0, 1].

    Returns a float32 scalar equal to true positives / (actual positives + epsilon).
    A batch with no actual positives therefore yields 0 instead of NaN.
    """
    backend = K.backend
    true_cls = backend.argmax(y_true)
    pred_cls = backend.argmax(y_pred)
    hit_count = backend.sum(backend.round(backend.clip(true_cls * pred_cls, 0, 1)))
    actual_count = backend.sum(backend.round(backend.clip(true_cls, 0, 1)))
    tp = backend.cast(hit_count, 'float32')
    actual = backend.cast(actual_count, 'float32')
    return tp / (actual + backend.epsilon())


def precision(y_true, y_pred):
    """Batch-level precision of the positive class.

    Both arguments are reduced with argmax over the last axis. Index 0 is the
    negative class; any non-zero index counts as positive after clipping to [0, 1].

    Returns a float32 scalar equal to true positives / (predicted positives + epsilon).
    A batch with no predicted positives therefore yields 0 instead of NaN.
    """
    backend = K.backend
    actual_cls = backend.argmax(y_true)
    predicted_cls = backend.argmax(y_pred)
    hit_count = backend.sum(backend.round(backend.clip(actual_cls * predicted_cls, 0, 1)))
    flagged_count = backend.sum(backend.round(backend.clip(predicted_cls, 0, 1)))
    tp = backend.cast(hit_count, 'float32')
    flagged = backend.cast(flagged_count, 'float32')
    return tp / (flagged + backend.epsilon())


def f1_score(y_true, y_pred):
    """Harmonic mean of the batch-level ``precision`` and ``recall``.

    Epsilon in the denominator keeps the result finite when both are 0. The value
    is computed per batch. When it is used as a plain compile() metric, the epoch
    figure is an average of batch values, not a dataset-level F1.
    """
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    denominator = p + r + K.backend.epsilon()
    return 2 * ((p * r) / denominator)


import scipy
import matplotlib.pyplot as plt
from PIL import Image as im
from scipy.ndimage import rotate


def translate(img, shift=50, direction='right', roll=True):
    """Shift an image by one third of its height, optionally wrapping the displaced strip.

    NOTE: the ``shift`` argument is ignored. The offset is always
    ``int(img.shape[0] / 3)`` pixels: rows for up/down, columns for left/right.
    This keeps the existing augmentation output unchanged; check every caller
    before making ``shift`` effective.

    With ``roll=True``, the strip pushed off one edge is written into the vacated
    strip on the opposite edge. For ``'right'`` that strip is mirrored left-right
    (``np.fliplr``); for the other directions it is copied as-is. With
    ``roll=False``, the vacated strip keeps the original pixels, so edge content
    is duplicated. The input array is never modified.

    Args:
        img: image array with rows on axis 0 and columns on axis 1.
        shift: unused (see note).
        direction: one of 'right', 'left', 'down', 'up'.
        roll: whether to wrap the displaced strip around.

    Returns:
        A new array with the same shape and dtype as ``img``. If the computed
        offset is 0 (images with fewer than 3 rows), an unchanged copy is returned.
    """
    assert direction in ['right', 'left', 'down', 'up'], 'Directions should be right|left|down|up'
    img = img.copy()
    offset = int(img.shape[0] / 3)
    if offset == 0:
        # Slices like img[:-0] would be empty and the assignments below would fail
        # to broadcast. Nothing moves, so return the copy as-is.
        return img
    if direction == 'right':
        right_slice = img[:, -offset:].copy()
        img[:, offset:] = img[:, :-offset]
        if roll:
            img[:, :offset] = np.fliplr(right_slice)
    elif direction == 'left':
        left_slice = img[:, :offset].copy()
        img[:, :-offset] = img[:, offset:]
        if roll:
            img[:, -offset:] = left_slice
    elif direction == 'down':
        down_slice = img[-offset:, :].copy()
        img[offset:, :] = img[:-offset, :]
        if roll:
            img[:offset, :] = down_slice
    else:  # 'up'
        upper_slice = img[:offset, :].copy()
        img[:-offset, :] = img[offset:, :]
        if roll:
            img[-offset:, :] = upper_slice
    return img


def random_crop(img, crop_size=(10, 10)):
    """Return a randomly positioned window of ``img`` of size ``crop_size``.

    Args:
        img: array with rows on axis 0 and columns on axis 1. Any trailing axes,
            e.g. channels, are kept.
        crop_size: ``(rows, cols)`` of the window. Each must not exceed the
            matching image dimension.

    Returns:
        A new array of shape ``(rows, cols) + img.shape[2:]``. Offsets are drawn
        uniformly over every valid position with ``np.random.randint``, so
        results follow NumPy's global random state.
    """
    assert crop_size[0] <= img.shape[0] and crop_size[1] <= img.shape[1], "Crop size should be less than image size"
    crop_rows, crop_cols = crop_size[0], crop_size[1]
    n_rows, n_cols = img.shape[:2]
    # The +1 lets the window end flush with the bottom/right edge. It also keeps
    # randint's upper bound positive when the crop equals the image size.
    x = np.random.randint(n_cols - crop_cols + 1)
    y = np.random.randint(n_rows - crop_rows + 1)
    return img[y:y + crop_rows, x:x + crop_cols].copy()


# plot_grid([translate(img, direction='up', shift=20),
#           translate(img, direction='down', shift=20),
#           translate(img, direction='left', shift=20),
#           translate(img, direction='right', shift=20)],
#           1, 4, figsize=(10, 5))

def rotate_img(img, angle, bg_patch=(5, 5)):
    """Rotate ``img`` by ``angle`` degrees and repaint non-positive pixels.

    The repaint colour is the mean of the top-left ``bg_patch`` region of the
    input, taken per channel for 3-D images. The image is rotated with
    ``scipy.ndimage.rotate(..., reshape=False)``, so the output keeps the input
    shape. Every pixel that is <= 0 (in any channel, for 3-D input) is then
    overwritten with that colour. This is how the zero-filled regions created by
    the rotation get painted.
    """
    assert len(img.shape) <= 3, "Incorrect image shape"
    patch_rows, patch_cols = bg_patch[0], bg_patch[1]
    is_color = len(img.shape) == 3
    if is_color:
        fill = np.mean(img[:patch_rows, :patch_cols, :], axis=(0, 1))
    else:
        fill = np.mean(img[:patch_rows, :patch_cols])
    rotated = rotate(img, angle, reshape=False)
    if is_color:
        dark = np.any(rotated <= 0, axis=-1)
    else:
        dark = rotated <= 0
    rotated[dark] = fill
    return rotated


def distort(img, orientation='horizontal', func=np.sin, x_scale=0.05, y_scale=5):
    """Apply a wave distortion by cyclically rolling each row or column.

    With a 'hor…' orientation, row ``i`` is rolled along the column axis by
    ``int(y_scale * func(pi * i * x_scale))`` pixels. With 'ver…', column ``i``
    is rolled along the row axis by the same amount. All channels of a line move
    together, so both 2-D (grayscale) and 3-D images with any channel count work.

    Args:
        img: image array with rows on axis 0 and columns on axis 1.
        orientation: string starting with 'hor' or 'ver'.
        func: ``np.sin`` or ``np.cos``.
        x_scale: wave frequency factor in [0.0, 0.1].
        y_scale: wave amplitude in pixels, at most the smaller image side.

    Returns:
        A new array with the same shape and dtype. ``img`` is not modified.
    """
    assert orientation[:3] in ['hor', 'ver'], "dist_orient should be 'horizontal'|'vertical'"
    assert func in [np.sin, np.cos], "supported functions are np.sin and np.cos"
    assert 0.00 <= x_scale <= 0.1, "x_scale should be in [0.0, 0.1]"
    assert 0 <= y_scale <= min(img.shape[0], img.shape[1]), "y_scale should be less then image size"
    img_dist = img.copy()

    def shift(x):
        return int(y_scale * func(np.pi * x * x_scale))

    if orientation.startswith('ver'):
        for i in range(img.shape[1]):
            # img[:, i] is (rows,) or (rows, channels); rolling axis 0 moves every channel alike.
            img_dist[:, i] = np.roll(img[:, i], shift(i), axis=0)
    else:
        for i in range(img.shape[0]):
            # img[i] is (cols,) or (cols, channels); rolling axis 0 moves every channel alike.
            img_dist[i] = np.roll(img[i], shift(i), axis=0)

    return img_dist


def imgaugmentation(img):
    """Build a list of augmented copies of ``img`` according to the module-level flags.

    Each global switch enables one family of augmentations when it equals the
    string "true":
      * img_transfer: four translations (up, down, left, right);
      * img_random:   four random crops, one third of each side, resized back to
                      the original size with nearest-neighbour interpolation;
      * img_rotate:   eleven rotations in 30-degree steps (30..330 degrees);
      * img_distort:  four wave distortions alternating horizontal/vertical.
    Returns an empty list when no switch is enabled.
    """
    dims = np.array(img).shape
    width, height = dims[1], dims[0]
    results = []
    if img_transfer == "true":
        for side in ('up', 'down', 'left', 'right'):
            results.append(translate(img, direction=side, shift=20))
    if img_random == "true":
        window = (int(height / 3), int(width / 3))
        for _ in range(4):
            patch = random_crop(img, crop_size=window)
            results.append(cv2.resize(patch, (width, height), interpolation=cv2.INTER_NEAREST))
    if img_rotate == "true":
        results.extend(rotate_img(img, step * 30) for step in range(1, 12))
    if img_distort == "true":
        wave_settings = (('hor', 0.01, 2), ('ver', 0.02, 4), ('hor', 0.03, 6), ('ver', 0.04, 8))
        for orient, xs, ys in wave_settings:
            results.append(distort(img, orient, x_scale=xs, y_scale=ys))
    return results


def gettrainlabelfile():
    """Find the training label CSV that belongs to the PNG images in ``trainingfolder``.

    For each ``.png`` file, in ``os.listdir`` order, the candidate name is the
    stem text before its last underscore plus ``"_outputlabel.csv"``. For example,
    ``scan_001.png`` gives ``scan_outputlabel.csv``. The first candidate that
    exists in the same folder is returned as a bare file name, not a full path.

    Returns:
        The matching CSV file name, or "" when no PNG has a matching CSV.
    """
    files = os.listdir(trainingfolder)
    print("training folder files", files)
    present = set(files)
    for f in files:
        fstruc = os.path.splitext(f)
        print("fstruc", fstruc)
        stem, ext = fstruc
        if ext != ".png":
            continue
        print("pngfilehead", stem)
        # NOTE: for a stem without "_", rfind() returns -1, which drops the stem's last character.
        csvfileexp = stem[:stem.rfind("_")] + "_outputlabel.csv"
        if csvfileexp in present:
            return csvfileexp
    return ""


def optimizers(opname, trainingrate):
    """Instantiate a ``tf.keras.optimizers`` optimizer by class name.

    Args:
        opname: one of "SGD", "RMSprop", "Adam", "Adadelta", "Adagrad",
            "Adamax", "Nadam", "Ftrl" (case-sensitive).
        trainingrate: learning rate; any value ``float()`` accepts, e.g. the raw
            command-line string.

    Returns:
        A new optimizer instance configured with ``learning_rate``.

    Raises:
        Exception: if ``opname`` has at most one character (no name given).
        ValueError: if ``opname`` is not one of the supported names.
    """
    if len(opname) <= 1:
        raise Exception("No optimizaer name")
    supported = ("SGD", "RMSprop", "Adam", "Adadelta", "Adagrad", "Adamax", "Nadam", "Ftrl")
    if opname not in supported:
        raise ValueError(f"Unsupported optimizer {opname!r}; expected one of: {', '.join(supported)}")
    # Look the class up lazily, so only the requested optimizer has to exist in
    # the installed TF/Keras version.
    optimizer_cls = getattr(tf.keras.optimizers, opname)
    return optimizer_cls(learning_rate=float(trainingrate))


def _pad_to_square(image):
    """Zero-pad a (height, width, channels) image along its shorter side so it becomes square.

    The original content is centred, with the offset rounded down. Images whose
    aspect ratio is within 5% of square are returned unchanged, keeping their
    dtype. Padded images are float64 because the canvas comes from ``np.zeros``.
    """
    height, width, channel = image.shape
    if max(height / width, width / height) <= 1.05:
        return image
    side = max(height, width)
    canvas = np.zeros((side, side, channel))
    if height > width:
        offset = (height - width) // 2
        canvas[:, offset:offset + width, :] = image
    else:
        offset = (width - height) // 2
        canvas[offset:offset + height, :, :] = image
    return canvas


def _to_model_input(image):
    """Resize ``image`` to the model's spatial input size and scale it to float32 in [0, 1].

    Assumes ``k_inputshape`` (taken from ``kerasmodel.input_shape``) is
    channels-last ``(batch, height, width, channels)``. TODO: confirm for every
    model this script is used with.
    """
    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(image, (k_inputshape[2], k_inputshape[1]), interpolation=cv2.INTER_NEAREST)
    return np.asarray(resized, dtype="float32") / 255


def _plot_training_history(history, valneglen, valposlen):
    """Save train/validation accuracy and F1 curves to ``imagefolder/NNMetrics.png``."""
    fig, axs = plt.subplots(2)
    epochs = range(1, len(history.history['accuracy']) + 1)
    ticks = range(1, trepoch + 1, 1)

    fig.suptitle("Metric results for training\n val_neg:" + str(valneglen) + " val_pos:" + str(valposlen))
    axs[0].plot(epochs, history.history['accuracy'], marker='o')
    axs[0].plot(epochs, history.history['val_accuracy'], marker='^')
    axs[0].set_title("Model accuracy")
    axs[0].set(xlabel="Epoch", ylabel="Accuracy")
    axs[0].set_xticks(ticks)

    print(history.history.keys())
    train_f1 = np.array(history.history['f1_score']).round(4)
    val_f1 = np.array(history.history['val_f1_score']).round(4)
    print(list(train_f1))
    print(list(val_f1))
    axs[1].plot(epochs, list(train_f1), marker='o')
    axs[1].plot(epochs, list(val_f1), marker='^')
    axs[1].set_title("Model F1 scores")
    axs[1].set(xlabel="Epoch", ylabel="F1_score")
    axs[1].set_xticks(ticks)

    axs[0].legend(['Train', 'Val'], loc='upper left')
    axs[1].legend(['Train', 'Val'], loc='upper left')
    fig.tight_layout(pad=1.0)
    fig.savefig(os.path.join(imagefolder, "NNMetrics.png"), dpi=100)
    plt.close(fig)  # release the figure; nothing else in this process reuses it
    print('finished plot.')


def train():
    """Fine-tune the global ``kerasmodel`` on the labelled training images.

    Data split:
        Reads ``traininglabel``, a CSV with ``filename`` and ``label`` columns.
        Label 0 is encoded as ``[1, 0]`` and label 1 as ``[0, 1]``. Each class is
        split in file order: its last ``int(n * (1 - trainratio))`` rows are used
        for validation and the rest for training. Image paths are written one per
        line to ``traindata.txt`` / ``validdata.txt`` in ``imagefolder``.

    Preprocessing:
        Every image is zero-padded to square, resized to the model input and
        scaled to [0, 1]. Training images are also expanded with ``imgaugmentation``.

    Fitting:
        The model is compiled with ``optimizers(optimiset, learningrate)``,
        ``lossfunc`` and accuracy/precision/recall/F1 metrics. It is fitted with
        early stopping on ``val_loss`` (patience 10, best weights restored) and a
        checkpoint after every epoch in ``imagefolder``. Accuracy and F1 curves
        are saved to ``imagefolder/NNMetrics.png``.

    Returns immediately when ``trainingfolder`` or ``traininglabel`` has length 1
    (presumably a one-character placeholder meaning "no training data").

    Raises:
        FileNotFoundError: if a training image listed in the CSV cannot be read.
    """
    if len(trainingfolder) == 1 or len(traininglabel) == 1:
        print('No training data, going to predict test set directly.')
        return

    traininglabelfd = pd.read_csv(traininglabel)
    trlabels = np.array(traininglabelfd["label"])
    trnames = list(np.array(traininglabelfd["filename"]))
    neglabels = list(np.where(trlabels == 0)[0])
    poslabels = list(np.where(trlabels == 1)[0])
    print("neglabels", neglabels, "poslabels", poslabels)
    # Per class, the last (1 - trainratio) share of rows, in CSV order, is held out for validation.
    valneglen = int(len(neglabels) * (1 - trainratio))
    valposlen = int(len(poslabels) * (1 - trainratio))
    print("valneglen", valneglen, "valposlen", valposlen)
    trneglen = len(neglabels) - valneglen
    trposlen = len(poslabels) - valposlen
    print("trneglen", trneglen, "trposlen", trposlen)
    trposlabels = poslabels[:trposlen]
    print("trposlabels", trposlabels)
    valposlabels = poslabels[trposlen:]
    print("valposlabels", valposlabels)
    trneglabels = neglabels[:trneglen]
    print("trneglabels", trneglabels)
    valneglabels = neglabels[trneglen:]
    print("valneglabels", valneglabels)
    # Sets give O(1) membership tests inside the per-image loop.
    train_pos = {int(k) for k in trposlabels}
    train_neg = {int(k) for k in trneglabels}
    val_pos = {int(k) for k in valposlabels}
    val_neg = {int(k) for k in valneglabels}

    x_train, y_train, x_val, y_val = [], [], [], []
    with open(os.path.join(imagefolder, "traindata.txt"), "w") as trainfiles, \
            open(os.path.join(imagefolder, "validdata.txt"), "w") as validfiles:
        for i, name in enumerate(trnames):
            if i in train_pos or i in train_neg:
                is_train = True
                target = [0, 1] if i in train_pos else [1, 0]
            elif i in val_pos or i in val_neg:
                is_train = False
                target = [0, 1] if i in val_pos else [1, 0]
            else:
                continue  # label is neither 0 nor 1: not part of either split

            imagepath = os.path.join(trainingfolder, name)
            # NOTE(review): cv2.COLOR_BGR2RGB is a cvtColor code. Used as an imread
            # flag its value (4) means IMREAD_ANYCOLOR, so no BGR->RGB conversion
            # happens here. validate() reads images the same way, so both stay consistent.
            image = cv2.imread(imagepath, cv2.COLOR_BGR2RGB)
            if image is None:
                raise FileNotFoundError("Cannot read training image: " + imagepath)
            print('image shape =', image.shape)
            image = _pad_to_square(image)
            print('image shape 2 =', image.shape)

            if is_train:
                x_train.append(_to_model_input(image))
                y_train.append(target)
                trainfiles.write(imagepath + "\n")
                # Augmented copies are only used for training, so only build them here.
                for augmented in imgaugmentation(image):
                    x_train.append(_to_model_input(augmented))
                    y_train.append(target)
            else:
                x_val.append(_to_model_input(image))
                y_val.append(target)
                validfiles.write(imagepath + "\n")

    x_train = np.array(x_train)
    y_train = np.array(y_train)
    x_val = np.array(x_val)
    y_val = np.array(y_val)

    checkpoint_filepath = os.path.join(
        imagefolder, "epoch{epoch:02d}-val_acc{val_accuracy:.2f}-val_loss{val_loss:.2f}-.hdf5")
    earlystop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    checkpoint = K.callbacks.ModelCheckpoint(filepath=checkpoint_filepath,
                                             save_best_only=False, monitor='val_loss', mode='auto',
                                             verbose=1)

    optimizer = optimizers(optimiset, learningrate)
    kerasmodel.compile(optimizer=optimizer, loss=lossfunc, metrics=["accuracy", precision, recall, f1_score])

    history = kerasmodel.fit(x_train, y_train, batch_size=batchsize,
                             epochs=trepoch, validation_data=(x_val, y_val), callbacks=[earlystop, checkpoint])

    _plot_training_history(history, valneglen, valposlen)


def validate():
    """Predict every image listed in the validation label CSV, then save the results and the model.

    Preprocessing:
        Images named in ``imgnames`` (the ``filename`` column of
        ``imagefolder/labelfilename``, loaded by main()) are read from
        ``imagefolder``. Each is zero-padded to square, resized to the model
        input size and scaled to [0, 1]. Images that fail are reported and skipped.

    Output CSV:
        The same CSV is re-read and rewritten with three columns inserted at
        positions 5-7: ``predict`` (argmax class), ``confidence_neg`` (model
        output column 0) and ``confidence_pos`` (column 1). Rows whose image was
        skipped get NaN in these columns.

    Saved model:
        The architecture is written to ``imagefolder/my_model.json`` and the full
        model to ``imagefolder/testset_weights.hdf5``.
    """
    images = []
    loaded_rows = []  # positions in imgnames (== CSV rows) whose image was preprocessed
    print('imgnames length', len(imgnames))
    for row, f in enumerate(imgnames):
        try:
            imgpath = os.path.join(imagefolder, f)
            print(imgpath)
            # NOTE(review): cv2.COLOR_BGR2RGB used as an imread flag equals
            # IMREAD_ANYCOLOR (4), so no colour conversion happens. This matches
            # how train() reads images.
            image = cv2.imread(imgpath, cv2.COLOR_BGR2RGB)
            height, width, channel = image.shape
            if max(height / width, width / height) > 1.05:
                # Zero-pad the shorter side so the image is square, with the content centred.
                side = max(height, width)
                makeupimg = np.zeros((side, side, channel))
                if height > width:
                    extralen = int((height - width) / 2)
                    makeupimg[:, extralen:extralen + width, :] = image
                else:
                    extralen = int((width - height) / 2)
                    makeupimg[extralen:extralen + height, :, :] = image
                image = makeupimg
            # cv2.resize takes dsize as (width, height). k_inputshape is assumed to be
            # channels-last (batch, height, width, channels); TODO confirm.
            image = cv2.resize(image, (k_inputshape[2], k_inputshape[1]), interpolation=cv2.INTER_NEAREST)
            image = np.asarray(image, dtype="float32") / 255
        except Exception:
            # Best effort: report and skip unreadable or malformed images.
            print('cv2 error on ', f)
            print(traceback.format_exc())
            continue
        images.append(image)
        loaded_rows.append(row)
    print('finish preprocess validation image')
    images = np.array(images)
    print(images.shape)
    if len(images) > 0:
        yhat = kerasmodel.predict(images)
    else:
        # Nothing could be loaded; keep the two-column layout so the CSV still gets NaN columns.
        yhat = np.zeros((0, 2), dtype="float32")
    yhatlabels = list(np.argmax(yhat, axis=1))
    print(yhat, yhatlabels, type(yhatlabels))

    filepath = os.path.join(imagefolder, labelfilename)
    df = pd.read_csv(filepath)

    def per_row(values):
        # Map one value per loaded image onto the CSV rows; skipped rows become NaN.
        return pd.Series(list(values), index=loaded_rows).reindex(df.index)

    df.insert(5, "predict", per_row(yhatlabels), True)
    df.insert(6, "confidence_neg", per_row(yhat[:, 0]), True)
    df.insert(7, "confidence_pos", per_row(yhat[:, 1]), True)
    # NOTE(review): to_csv writes the DataFrame index as an extra unnamed first
    # column. Re-running on the rewritten CSV shifts the fixed insert positions
    # above; confirm whether index=False is wanted.
    df.to_csv(filepath)

    model_json = kerasmodel.to_json()
    with open(os.path.join(imagefolder, "my_model.json"), "w") as json_file:
        json_file.write(model_json)
    kerasmodel.save(os.path.join(imagefolder, "testset_weights.hdf5"))


def main():
    """Parse the 19 positional command-line arguments, load the model, then train and validate.

    Argument order (sys.argv[1..19]):
        1  jsonfile        - only printed
        2  imagefolder     - validation folder: label CSV, outputs, and the
                             directory the Keras model is loaded from
        3  trainingfolder  - folder with the training images
        4  batchsize       - int
        5  trepoch         - int, number of epochs
        6  optimiset       - optimizer class name (see optimizers())
        7  lossfunc        - Keras loss identifier
        8  labelfilename   - validation label CSV name inside imagefolder
        9  trainratio      - float, share of each class used for training
        10-13 img_transfer, img_random, img_rotate, img_distort
                           - augmentation switches, "true" enables
        14 traininglabel   - path of the training label CSV
        15-18 tf_finetune, freezemodel, dropout_rate, predictclass
                           - stored in globals and printed; not used otherwise here
        19 learningrate    - learning rate string, converted by optimizers()

    Exits with a message when fewer than 19 arguments are given.
    """
    # sys.argv[0] is the script itself, so 19 arguments means len(sys.argv) == 20.
    if len(sys.argv) < 20:
        sys.exit("Expected 19 command-line arguments, got " + str(len(sys.argv) - 1))

    global imagefolder, trainingfolder, batchsize, trepoch, optimiset, lossfunc
    global labelfilename, trainratio, img_transfer, img_random, img_rotate, img_distort
    global traininglabel, tf_finetune, freezemodel, dropout_rate, predictclass, learningrate
    global imgnames, labels, labeltable, kerasmodel, k_inputshape

    jsonfile = sys.argv[1]  # local on purpose: only printed below
    imagefolder = sys.argv[2]  # validation folder
    trainingfolder = sys.argv[3]  # training folder
    batchsize = int(sys.argv[4])
    trepoch = int(sys.argv[5])
    optimiset = sys.argv[6]
    lossfunc = sys.argv[7]
    labelfilename = sys.argv[8]
    trainratio = float(sys.argv[9])
    img_transfer = sys.argv[10]
    img_random = sys.argv[11]
    img_rotate = sys.argv[12]
    img_distort = sys.argv[13]
    traininglabel = sys.argv[14]
    tf_finetune = sys.argv[15]
    freezemodel = sys.argv[16]
    dropout_rate = sys.argv[17]
    predictclass = sys.argv[18]
    learningrate = sys.argv[19]
    print("model file:", jsonfile, "validation data folder:", imagefolder, "validation data label:", labelfilename,
          "training dataset folder", trainingfolder, "training label doc", traininglabel, "batchsize", batchsize,
          "training epoch", trepoch, "trainratio", trainratio, img_transfer, img_random, img_rotate, img_distort,
          "tf_finetune", tf_finetune, "freezemodel", freezemodel, "droupout_rate", dropout_rate, "predictclass",
          predictclass, "learningrate", learningrate)

    print(imagefolder, labelfilename)
    labeltable = pd.read_csv(os.path.join(imagefolder, labelfilename))
    imgnames = list(np.array(labeltable["filename"]))
    labels = labeltable["label"]

    # NOTE(review): the model is loaded from the validation folder itself (presumably
    # a SavedModel directory); jsonfile is not used for loading.
    kerasmodel = K.models.load_model(imagefolder)
    kerasmodel.summary()  # summary() prints the table itself

    k_inputshape = kerasmodel.input_shape

    train()
    validate()


if __name__ == '__main__':
    main()

